// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "base_reference_test.hpp"
#include "inference_engine.hpp"

namespace reference_tests {

/// Base fixture for reference tests that run a model through two inference paths:
/// the OpenVINO 2.0 runtime API (ov::runtime::Core / ExecutableNetwork /
/// InferRequest) and the legacy InferenceEngine API (InferenceEngine::Core /
/// CNNNetwork / ExecutableNetwork / InferRequest).
///
/// Results of each path are kept separately (`outputs_ov20`, `outputs_legacy`);
/// presumably they are compared against each other in Validate() — NOTE(review):
/// confirm in the implementation file.
class ReferenceCNNTest {
public:
    ReferenceCNNTest();

    /// Validate() is virtual, so this class is used polymorphically. A virtual
    /// destructor keeps destruction of a derived test through a
    /// ReferenceCNNTest pointer well-defined.
    virtual ~ReferenceCNNTest() = default;

    /// Entry point of a test case.
    /// NOTE(review): presumably drives the load/fill/infer/validate steps below —
    /// the sequencing is defined in the .cpp, confirm there.
    void Exec();

    /// OV 2.0 API path: compile the model and run inference.
    void LoadNetwork();
    void FillInputs(); // Both for legacy and for OV2.0 API
    void Infer();

    /// Legacy InferenceEngine API path: load `legacy_network` and run inference.
    void LoadNetworkLegacy();
    void InferLegacy();

    /// Checks the produced outputs. Derived tests may override it to customize
    /// validation.
    virtual void Validate();

protected:
    const std::string targetDevice;           // device the networks are loaded on
    std::shared_ptr<ov::Model> function;      // model under test (OV 2.0 representation)
    InferenceEngine::CNNNetwork legacy_network; // model under test (legacy representation)

    float threshold = 1e-5f;    // Relative diff
    float abs_threshold = -1.f; // Absolute diff (not used when negative)

    std::vector<ov::runtime::Tensor> outputs_ov20;   // results of the OV 2.0 API path
    std::vector<ov::runtime::Tensor> outputs_legacy; // results of the legacy API path

protected:
    // These will be generated by default, if user has not specified inputs manually
    std::vector<ov::runtime::Tensor> inputData;
    InferenceEngine::BlobMap legacy_input_blobs;

private:
    // OV 2.0 API runtime objects.
    std::shared_ptr<ov::runtime::Core> core;
    ov::runtime::ExecutableNetwork executableNetwork;
    ov::runtime::InferRequest inferRequest;

    // Legacy InferenceEngine API runtime objects.
    std::shared_ptr<InferenceEngine::Core> legacy_core;
    InferenceEngine::ExecutableNetwork legacy_exec_network;
    InferenceEngine::InferRequest legacy_infer_request;
};

}  // namespace reference_tests
